# ------------------------------------------------------------------------------
# Fits GBM
# From: stress_test_medicare repo <https://gitlab.com/labsysmed/zolab-projects/stress_test_medicare/-/blob/master/code/03_analysis/01_build-models/03_fit_gbm.R>
# Updates author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 04_fit_gbm.sh {outcome} {split} {restriction}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix R's random number generator state at a constant seed before any other
# work is done.
# NOTE(review): presumably this is meant to make the stochastic parts of the
# fit below (`subsample`, `colsample_bytree`) reproducible -- confirm that the
# installed xgboost version draws its randomness from R's RNG.
set.seed(1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(Matrix)
library(testit) # assert()
library(optparse)
library(here)
library(xgboost)

# Command Line Args ------------------------------------------------------------
# Three character options are read from the command line. None has a default,
# so an omitted option is simply absent (NULL) from the parsed list; validate
# up front so the script fails with a clear message instead of an opaque
# "argument is of length zero" (or a zero-length path) further down.
arg_config <- list(
  make_option(
    "--outcome",
    type = "character",
    help = "Name of the model config file (without .yml) under model_config/"
  ),
  make_option(
    "--split",
    type = "character",
    help = "Data split; used as a sub-directory for cohorts/features/models"
  ),
  make_option(
    "--restriction",
    type = "character",
    help = "Feature restriction label (e.g. dropcc, justcc, dem, enc, ...)"
  )
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

required_args <- c("outcome", "split", "restriction")
missing_args <- Filter(function(nm) is.null(arg_list[[nm]]), required_args)
if (length(missing_args) > 0) {
  stop(
    "Missing required argument(s): ",
    paste0("--", missing_args, collapse = ", "),
    call. = FALSE
  )
}

# Kept for backward compatibility; the visible script reads `arg_list$split`
# directly rather than this copy.
split <- arg_list$split

# Directories ------------------------------------------------------------------
message("Establishing directories...")
paths <- read_yaml(here("lib", "filepaths.yml"))

# Builds <modeling dir>/<kind>/<split>[/<extra components>].
modeling_subdir <- function(kind, ...) {
  file.path(paths$modeling$dir, kind, arg_list$split, ...)
}

# Inputs: training cohort, feature matrix and tuning results for this run.
cohort_dir <- modeling_subdir("cohorts")
features_dir <- file.path(paths$features$dir, arg_list$split)
tuning_dir <- modeling_subdir("tuning", arg_list$restriction)

# Output: where the fitted model is written.
save_dir <- modeling_subdir("models", arg_list$restriction)

# Location of the per-outcome model config files inside the repo.
build_models_dir <- here("code", "03_analysis", "01_build_models")

# Load Data --------------------------------------------------------------------
message("Loading data...")

# The outcome name selects which YAML model config to read.
config_path <- file.path(
  build_models_dir, "model_config", glue("{arg_list$outcome}.yml")
)
config <- read_yaml(config_path)

# The config's `downsample` value is embedded in the cohort/feature file names.
ds <- config$downsample
cohort_path <- file.path(cohort_dir, glue("train_cohort_{ds}.rds"))
features_path <- file.path(features_dir, glue("train_features_{ds}.rds"))

ids <- readRDS(cohort_path)
x <- readRDS(features_path)

# Restrict the feature matrix to the column subset named by --restriction.
#   dropcc    : every column whose name does NOT match "ed_enc_t0d"
#   justcc    : only columns whose name matches "ed_enc_t0d"
#   enc       : columns matching "enc_" but not "_cc_"
#   dem, dia, lab, lvs, med, prc : columns matching "<restriction>_"
#   represent : columns listed in the RDS at paths$analysis$representative_vars
# Any other restriction value leaves `x` untouched (all features are kept).
feat_names <- colnames(x)
keep_feats <- switch(arg_list$restriction,
  dropcc = which(!grepl("ed_enc_t0d", feat_names)),
  justcc = which(grepl("ed_enc_t0d", feat_names)),
  enc = which(grepl("enc_", feat_names) & !grepl("_cc_", feat_names)),
  # Empty arguments fall through to the next non-empty branch, so these six
  # labels share one pattern built from the label itself ("dem_", "dia_", ...).
  dem = ,
  dia = ,
  lab = ,
  lvs = ,
  med = ,
  prc = which(grepl(paste0(arg_list$restriction, "_"), feat_names)),
  represent = {
    representative_vars <- readRDS(paths$analysis$representative_vars)
    which(feat_names %in% representative_vars)
  },
  NULL
)

if (!is.null(keep_feats)) {
  # drop = FALSE keeps `x` two-dimensional even when exactly one column
  # matches; without it the matrix would collapse to a plain vector.
  x <- x[, keep_feats, drop = FALSE]
}

# Tuning results are stored per target/population under the tuning directory.
# `config` is a list, so interpolate with glue_data(), which looks up
# {target} and {population} in the list; glue()'s `.envir` argument is
# documented as taking an environment rather than a list.
tuning_path <- file.path(
  tuning_dir,
  glue_data(config, "gbm__{target}__{population}.rds")
)
tuning <- readRDS(tuning_path)

# Subset Data ------------------------------------------------------------------
message("Subsetting data...")
# Row mask for the population named in the model config. An unrecognised
# population is an error: without a default branch switch() would return NULL,
# the mask would become zero-length, and every training row would be dropped
# silently.
keep_pop <- switch(config$population,
  all = rep(TRUE, nrow(ids)),
  tested = ids$test_010_day == TRUE,
  untested = ids$test_010_day == FALSE,
  stop(
    "Unrecognized population in model config: ", config$population,
    call. = FALSE
  )
)
# Rows flagged `exclude_modeling` are removed regardless of population.
keep_pop <- keep_pop & !ids$exclude_modeling
# which() also discards rows where the mask is NA.
keep <- which(keep_pop)

# drop = FALSE keeps `x` two-dimensional even if a single row survives.
x <- x[keep, , drop = FALSE]
ids <- ids[keep, ]

# Tuning -----------------------------------------------------------------------
message("Identifying best parameters...")
# Pick the tuning row with the lowest logloss, breaking ties by the largest
# max_depth, then subsample, then colsample_bytree. top_n() keeps all tied
# rows, so after the chain more than one row can remain (e.g. rows differing
# only in eta); slice(1) pins the choice to a single row so that every
# hyperparameter handed to the fit is a scalar.
# NOTE(review): unnest() without `cols` is deprecated in tidyr >= 1.0 and emits
# a warning there; left unchanged to stay compatible with the tidyr in use.
best_params <- tuning %>%
  unnest() %>%
  top_n(1, -logloss) %>%
  top_n(1, max_depth) %>%
  top_n(1, subsample) %>%
  top_n(1, colsample_bytree) %>%
  slice(1) %>%
  # Move logloss to the last column.
  select(-logloss, everything())

# Fail early if the tuning results were empty.
stopifnot(nrow(best_params) == 1L)

message("Best params:")
iwalk(best_params, ~ message(.y, " = ", .x))

# Fit --------------------------------------------------------------------------
# Fit a binary-logistic xgboost model on the restricted, subsetted feature
# matrix `x`, using the config's target column of `ids` as the label and the
# hyperparameters selected from the tuning results.
message("Fitting gbm...")
gbm <- xgboost(
  data = x,
  label = ids[[config$target]],
  eta = best_params$eta,
  # NOTE(review): `num_iterations` does not appear among xgboost's documented
  # booster parameters, so it is presumably forwarded through `...` and
  # ignored -- confirm. If the tuned number of rounds is meant to apply, it
  # likely belongs in `nrounds` below.
  num_iterations = best_params$num_iterations,
  max_depth = best_params$max_depth,
  subsample = best_params$subsample,
  colsample_bytree = best_params$colsample_bytree,
  objective = "binary:logistic",
  # Thread count is set to the number of distinct training folds in `ids`.
  nthread = n_distinct(ids$train_fold),
  # Upper bound on boosting rounds.
  nrounds = 10000L,
  # NOTE(review): no separate validation set is passed in this call, so early
  # stopping presumably monitors the training data only -- confirm that this
  # ever triggers before `nrounds` is reached.
  early_stopping_rounds = 20L,
  verbose = 1
)

# Save Data --------------------------------------------------------------------
# Make sure the output directory exists before writing: the fit above can run
# for a long time, and a missing directory would otherwise lose the model at
# the very last step. Does nothing if the directory is already present.
dir.create(save_dir, recursive = TRUE, showWarnings = FALSE)

# `config` is a list, so interpolate {target} and {population} with
# glue_data() rather than passing the list as glue()'s `.envir`.
gbm_path <- file.path(
  save_dir,
  glue_data(config, "gbm__{target}__{population}.rds")
)
message("Saving ", arg_list$outcome, " GBM to ", gbm_path)
# NOTE(review): the file name ends in .rds but it is written with xgb.save(),
# not saveRDS() -- presumably consumers read it back with xgb.load(); confirm
# before changing either the extension or the writer.
xgb.save(gbm, gbm_path)

# Done -------------------------------------------------------------------------
message("Done.")
